import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
# from cplex import Cplex
from gurobipy import *
from ortools.graph.python import min_cost_flow

def Round(df, centers, color_flag, phi):
    """Round a fractional assignment color group by color group.

    For every attribute in ``color_flag`` and every integer color between that
    attribute's minimum and maximum label, the points carrying that color are
    re-assigned by :func:`rounding_color`, which solves a small 0/1 program
    over the squared point-to-center distances.

    Parameters
    ----------
    df : pandas.DataFrame
        Client coordinates, one row per point.
    centers : array-like
        Center coordinates, one row per center.
    color_flag : dict
        Maps each attribute to a per-point sequence of integer color labels.
    phi : numpy.ndarray
        Fractional assignment with one row per point and one column per
        center. **Modified in place**: the rows of every processed color group
        are overwritten with the rounded assignment.

    Returns
    -------
    tuple
        ``(rounded_cost, phi)`` -- the summed objective of all per-color
        programs and the (same, mutated) ``phi`` array.
    """
    sq_cost = cdist(df.values, centers) ** 2

    rounded_cost = 0
    for var in color_flag:
        labels = np.asarray(color_flag[var])
        for color in range(labels.min(), labels.max() + 1):
            members = np.flatnonzero(labels == color)
            # Colors inside [min, max] may have no points; skip them instead
            # of building an empty solver model.
            if members.size == 0:
                continue
            # phi is read after earlier attributes have already rewritten it,
            # so later attributes round the previously rounded rows.
            res = rounding_color(sq_cost[members], phi[members])
            rounded_cost += res['objective']
            phi[members] = res['assignment']

    return rounded_cost, phi

def Round_new(df, centers, color_flag, phi, *, cost_scale=1.0, verbose=False):
    """Round a fractional assignment ``phi`` with a single min-cost flow.

    Nodes, in numbering order:

    * node 0 -- a hub that feeds every point and collects from every center,
      closing the network into a circulation;
    * nodes ``1..n`` -- one per client point;
    * ``k * m`` "colored center" nodes -- one per (color group, center) pair,
      color groups enumerated attribute by attribute, color by color;
    * ``k`` center nodes.

    Arcs:

    * hub -> point, flow fixed to exactly 1;
    * point -> colored center of the point's own color group, capacity 1,
      unit cost = squared point/center distance;
    * colored center -> center, flow between floor and ceil of the fractional
      mass that color group sends to that center in ``phi``;
    * center -> hub, flow between floor and ceil of the center's total
      fractional mass in ``phi``.

    SimpleMinCostFlow only supports ``[0, capacity]`` arcs, so lower bounds are
    removed with the standard transformation (capacity ``upper - lower``,
    ``lower`` units of demand at the tail, ``lower`` units of supply at the
    head) and added back to the reported flows.

    Parameters
    ----------
    df : pandas.DataFrame
        Client coordinates, one row per point.
    centers : numpy.ndarray
        Center coordinates, one row per center.
    color_flag : dict
        Maps each attribute to a per-point sequence of integer colors.
        NOTE(review): every point sends exactly one unit of flow, so the
        network is only guaranteed feasible when the color groups partition
        the points (a single attribute).
    phi : numpy.ndarray
        Fractional assignment, one row per point, one column per center.
        Not modified.
    cost_scale : float, keyword-only
        Squared distances are multiplied by this and rounded to integers,
        because the solver only accepts integral costs. Increase it when
        squared distances are small (e.g. normalized data).
    verbose : bool, keyword-only
        Print the optimal (scaled) cost and a per-arc flow table.

    Returns
    -------
    costs : numpy.ndarray
        Per-arc ``flow * squared distance`` in original (unscaled) units;
        non-zero only on point -> colored-center arcs.
    flows : numpy.ndarray
        Per-arc flow in the original, lower-bounded network.

    Raises
    ------
    RuntimeError
        If the solver does not report an optimal solution (e.g. infeasible).
    """
    n = df.shape[0]
    k = centers.shape[0]
    phi = np.asarray(phi, dtype=float)
    sq_cost = cdist(df.values, centers) ** 2

    # Point indices of every color group, in attribute-then-color order.
    groups = []
    for var in color_flag:
        labels = np.asarray(color_flag[var])
        groups.extend(
            np.flatnonzero(labels == color)
            for color in range(labels.min(), labels.max() + 1)
        )
    m = len(groups)

    hub = 0
    colored_base = n + 1          # node of (group g, center c): colored_base + g*k + c
    center_base = n + 1 + k * m   # node of center c: center_base + c
    num_nodes = center_base + k
    center_ids = np.arange(k)

    arc_tail, arc_head, arc_lower, arc_upper, arc_cost = [], [], [], [], []

    def add_layer(tail, head, lower, upper, unit_cost):
        """Queue one layer of arcs; scalar attributes broadcast to len(tail)."""
        size = len(tail)
        arc_tail.append(np.asarray(tail))
        arc_head.append(np.asarray(head))
        arc_lower.append(np.broadcast_to(lower, (size,)))
        arc_upper.append(np.broadcast_to(upper, (size,)))
        arc_cost.append(np.broadcast_to(unit_cost, (size,)))

    # hub -> point: every point carries exactly one unit of flow.
    add_layer(np.full(n, hub), np.arange(1, n + 1), 1, 1, 0.0)

    # point -> (its color group, center): capacity 1, squared-distance cost.
    group_mass = np.zeros((m, k))
    for g, members in enumerate(groups):
        add_layer(
            np.repeat(members + 1, k),                                 # p0 x k, p1 x k, ...
            np.tile(colored_base + g * k + center_ids, members.size),  # c0..ck-1 per point
            0,
            1,
            sq_cost[members].reshape(-1),                              # row-major, same order
        )
        group_mass[g] = phi[members].sum(0)

    # (group, center) -> center: flow within floor/ceil of the group's mass.
    add_layer(
        colored_base + np.arange(k * m),
        np.tile(center_base + center_ids, m),
        np.floor(group_mass).reshape(-1),
        np.ceil(group_mass).reshape(-1),
        0.0,
    )

    # center -> hub: flow within floor/ceil of the center's total mass.
    center_mass = phi.sum(0)
    add_layer(
        center_base + center_ids,
        np.full(k, hub),
        np.floor(center_mass),
        np.ceil(center_mass),
        0.0,
    )

    tails = np.concatenate(arc_tail).astype(np.int32)
    heads = np.concatenate(arc_head).astype(np.int32)
    lower = np.concatenate(arc_lower).astype(np.int64)
    upper = np.concatenate(arc_upper).astype(np.int64)
    unit_cost = np.concatenate(arc_cost).astype(float)

    # Lower-bound removal: an arc (u, v) with flow in [l, h] becomes capacity
    # h - l plus l units of demand at u and l units of supply at v. Using
    # unbuffered ufunc.at keeps repeated tails/heads accumulating correctly,
    # and guarantees the supplies sum to zero.
    supplies = np.zeros(num_nodes, dtype=np.int64)
    np.subtract.at(supplies, tails, lower)
    np.add.at(supplies, heads, lower)

    smcf = min_cost_flow.SimpleMinCostFlow()
    arcs = smcf.add_arcs_with_capacity_and_unit_cost(
        tails, heads, upper - lower, np.rint(unit_cost * cost_scale).astype(np.int64)
    )
    smcf.set_nodes_supplies(np.arange(num_nodes, dtype=np.int32), supplies)

    status = smcf.solve()
    if status != smcf.OPTIMAL:
        raise RuntimeError(
            f"There was an issue with the min cost flow input. Status: {status}"
        )

    # Add the removed lower bounds back to obtain flows of the original network.
    flows = smcf.flows(arcs) + lower
    costs = flows * unit_cost

    if verbose:
        print(f"Minimum cost: {smcf.optimal_cost()}")
        print("")
        print(" Arc    Flow / Capacity Cost")
        for tail, head, flow, cap, cost in zip(tails, heads, flows, upper, costs):
            print(f"{tail:1} -> {head}  {flow:3}  / {cap:3}       {cost}")

    return costs, flows

def rounding_color(color_cost_matrix, color_phi):
    """Integrally re-assign one color group's points to centers.

    Let ``gamma[j]`` be the fractional mass that ``color_phi`` sends to center
    ``j``. The 0/1 program assigns every point to exactly one center such that
    center ``j`` receives exactly ``floor(gamma[j])`` points through its
    "integral" slots plus at most one extra point through its "fractional"
    slot, minimizing the total assignment cost.

    Parameters
    ----------
    color_cost_matrix : numpy.ndarray
        Cost of assigning each point (row) to each center (column).
    color_phi : numpy.ndarray
        Fractional assignment of the same points, same layout.

    Returns
    -------
    dict
        ``"status"`` -- Gurobi model status; ``"objective"`` -- objective
        value; ``'assignment'`` -- rounded 0/1 assignment, one row per point
        and one column per center.
    """
    cost = np.asarray(color_cost_matrix)
    n, k = cost.shape

    # Weight of each center; floor(gamma[j]) integral slots per center.
    gamma = np.asarray(color_phi).sum(0)
    floor_gamma = np.floor(gamma)

    problem = Model('mip')
    # Silence solver logging.
    problem.setParam('OutputFlag', 0)

    # Variables indexed (center j, point i).
    integral_x = problem.addVars(k, n, vtype=GRB.BINARY)
    fractional_x = problem.addVars(k, n, vtype=GRB.BINARY)

    problem.setObjective(
        quicksum(
            float(cost[i, j]) * (integral_x[j, i] + fractional_x[j, i])
            for i in range(n)
            for j in range(k)
        ),
        GRB.MINIMIZE,
    )

    # quicksum (unlike builtin sum) yields a LinExpr even over an empty range,
    # and float(...) keeps the right-hand side scalar, so every generator item
    # is a valid Gurobi constraint.
    problem.addConstrs(
        quicksum(integral_x[j, i] for i in range(n)) == float(floor_gamma[j])
        for j in range(k)
    )
    problem.addConstrs(
        quicksum(fractional_x[j, i] for i in range(n)) <= 1 for j in range(k)
    )
    problem.addConstrs(
        quicksum(integral_x[j, i] + fractional_x[j, i] for j in range(k)) == 1
        for i in range(n)
    )

    problem.optimize()

    # Read values by key rather than relying on dict ordering + reshape.
    assignment = np.zeros((n, k))
    for (j, i), var in integral_x.items():
        assignment[i, j] += var.X
    for (j, i), var in fractional_x.items():
        assignment[i, j] += var.X

    res = {
        "status": problem.Status,
        "objective": problem.ObjVal,
        # Binary values may be off by the integrality tolerance; snap to 0/1.
        'assignment': np.rint(assignment),
    }

    return res
    